# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Dict

import jax.numpy as jnp
from jax import ops

from brainpy.math.interoperability import as_jax
from brainpy.math.ndarray import Array as Array

# Public API of this module: only the ``seg_matmul`` dispatcher is exported;
# the underscore-prefixed ``_matmul_with_*`` helpers are private.
__all__ = [
    'seg_matmul',
]


def _matmul_with_left_sparse(
    sparse: Dict,
    dense: Union[Array, jnp.ndarray]
):
    r"""Matrix multiplication with sparse matrix on the left.

    .. math::

      Y = M_{\mathrm{sparse}} @ M_{\mathrm{dense}}

    The sparse matrix is given in coordinate (COO) form: the non-zero entry
    ``data[i]`` sits at position ``(rows[i], cols[i])``. Because products are
    accumulated with ``segment_sum``, repeated coordinates contribute additively.

    Parameters::

    sparse: dict
      The sparse matrix with shape of :math:`(N, M)`, stored as
      ``{'data': values, 'index': (rows, cols), 'shape': (N, M)}``.
    dense: ArrayType
      The dense matrix with the shape of :math:`(M, K)`, or a vector with the
      shape of :math:`(M,)`. Anything accepted by ``as_jax`` may be passed.

    Returns::

    matrix
      A tensor with the shape of :math:`(N, K)` (or :math:`(N,)` when ``dense``
      is a vector).

    Raises::

    ValueError
      If the sparse shape is not two-dimensional, if ``dense`` is not one- or
      two-dimensional, or if ``dense.shape[0]`` does not equal ``M``.
    """
    values = sparse['data']
    rows, cols = sparse['index']
    shape = sparse['shape']
    if len(shape) != 2:
        raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}')
    values = as_jax(values)
    rows = as_jax(rows)
    cols = as_jax(cols)
    # Convert before inspecting ``ndim`` so plain sequences / NumPy arrays work too.
    dense = as_jax(dense)
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    if dense.ndim not in (1, 2):
        raise ValueError('Dense matrix must be a one- or two-dimensional matrix.')
    # JAX gathers do not raise on out-of-range indices (they fill/clamp), and an
    # oversized ``dense`` would silently use only its first rows, so a shape
    # mismatch must be caught here or the result is silently wrong.
    if dense.shape[0] != shape[1]:
        raise ValueError(f'Shape mismatch: sparse matrix has shape {shape}, but dense '
                         f'matrix has shape {dense.shape}; expected dense.shape[0] '
                         f'== {shape[1]}.')
    # Pick the dense row addressed by each non-zero's column index:
    # result is (nnz,) for a vector or (nnz, K) for a matrix.
    gathered = dense.take(cols, axis=0)
    # Scale each gathered row by its non-zero value (broadcast over K in 2-D).
    if gathered.ndim == 2:
        prod = gathered * jnp.reshape(values, (-1, 1))
    else:
        prod = gathered * values
    # Sum the scaled rows into their output row; ``shape[0]`` (N) segments.
    return ops.segment_sum(prod, rows, shape[0])


def _matmul_with_right_sparse(
    dense: Union[Array, jnp.ndarray],
    sparse: Dict
):
    r"""Matrix multiplication with sparse matrix on the right.

    .. math::

      Y = M_{\mathrm{dense}} @ M_{\mathrm{sparse}}

    The sparse matrix is given in coordinate (COO) form: the non-zero entry
    ``data[i]`` sits at position ``(rows[i], cols[i])``. Because products are
    accumulated with ``segment_sum``, repeated coordinates contribute additively.

    Parameters::

    dense: ArrayType
      The dense matrix with the shape of :math:`(N, M)`, or a vector with the
      shape of :math:`(M,)`. Anything accepted by ``as_jax`` may be passed.
    sparse: dict
      The sparse matrix with shape of :math:`(M, K)`, stored as
      ``{'data': values, 'index': (rows, cols), 'shape': (M, K)}``.

    Returns::

    matrix
      A tensor with the shape of :math:`(N, K)` (or :math:`(K,)` when ``dense``
      is a vector).

    Raises::

    ValueError
      If the sparse shape is not two-dimensional, if ``dense`` is not one- or
      two-dimensional, or if the last dimension of ``dense`` does not equal ``M``.
    """
    values = sparse['data']
    rows, cols = sparse['index']
    shape = sparse['shape']
    if len(shape) != 2:
        raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}')
    values = as_jax(values)
    rows = as_jax(rows)
    cols = as_jax(cols)
    # Convert before inspecting ``ndim`` so plain sequences / NumPy arrays work too.
    dense = as_jax(dense)
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    if dense.ndim not in (1, 2):
        raise ValueError('Dense matrix must be a one- or two-dimensional matrix.')
    # JAX indexing clamps out-of-range indices instead of raising, so a shape
    # mismatch would otherwise produce a silently wrong result.
    if dense.shape[-1] != shape[0]:
        raise ValueError(f'Shape mismatch: sparse matrix has shape {shape}, but dense '
                         f'matrix has shape {dense.shape}; expected dense.shape[-1] '
                         f'== {shape[0]}.')
    if dense.ndim == 2:
        # Dense columns addressed by each non-zero's row index: (N, nnz).
        A = dense[:, rows]
        # Scale by the non-zero values, then put nnz first for segment_sum: (nnz, N).
        prod = (A * values).T
        # Sum into the K output columns, (K, N), and transpose back to (N, K).
        res = ops.segment_sum(prod, cols, shape[1]).T
    else:
        prod = dense[rows] * values
        res = ops.segment_sum(prod, cols, shape[1])
    return res


def seg_matmul(A, B):
    r"""Sparse matrix multiplication.

    .. math::

       y = A @ B

    where exactly one of :math:`A` or :math:`B` is a sparse matrix. A sparse
    matrix is a dict ``{'data': values, 'index': (rows, cols), 'shape': shape}``;
    the other operand is a dense array (a brainpy ``Array``, a JAX array, or
    anything ``as_jax`` can convert, e.g. a NumPy array).
    :math:`A` and :math:`B` cannot be both sparse or both dense.

    Examples::

    >>> import brainpy.math as bm

    1. when the left matrix :math:`A` is a sparse matrix with the shape of :math:`(N, M)`,

    >>> # A is a sparse matrix (3, 4):
    >>> #   [[0, 2, 0, 4],
    >>> #    [1, 0, 0, 0],
    >>> #    [0, 3, 0, 2]]
    >>> values = bm.asarray([2, 4, 1, 3, 2])
    >>> rows = bm.asarray([0, 0, 1, 2, 2])
    >>> cols = bm.asarray([1, 3, 0, 1, 3])
    >>> sparse = {'data': values, 'index': (rows, cols), 'shape': (3, 4)}
    >>> B = bm.arange(4)
    >>> bm.sparse.seg_matmul(sparse, B)
    ArrayType([14,  0,  9], dtype=int32)
    >>> B = bm.random.rand(4, 3)
    >>> bm.sparse.seg_matmul(sparse, B)
    ArrayType([[3.8331761 , 1.3708692 , 4.510223  ],
              [0.9960836 , 0.37550318, 0.7370341 ],
              [2.3700516 , 0.7574289 , 4.1124535 ]], dtype=float32)

    2. when the right matrix :math:`B` is a sparse matrix with the shape of :math:`(M, K)`,

    >>> A = bm.arange(3)
    >>> bm.sparse.seg_matmul(A, sparse)
    ArrayType([1, 6, 0, 4], dtype=int32)
    >>> A = bm.random.rand(2, 3)
    >>> bm.sparse.seg_matmul(A, sparse)
    ArrayType([[0.438388  , 1.4346815 , 0.        , 2.361964  ],
              [0.9171978 , 1.1214957 , 0.        , 0.90534496]], dtype=float32)

    Parameters::

    A: tensor, sequence, dict
      The dense or sparse matrix with the shape of :math:`(N, M)`.
    B: tensor, sequence, dict
      The dense or sparse matrix with the shape of :math:`(M, K)`.

    Returns::

    results: ArrayType
      The tensor with the shape of :math:`(N, K)`.

    Raises::

    ValueError
      If both operands are sparse (dicts), if neither is, or if the shapes or
      dimensionalities are incompatible.
    """
    a_is_sparse = isinstance(A, dict)
    b_is_sparse = isinstance(B, dict)
    # Only report "both sparse" when both really are sparse; previously a dense
    # NumPy array/list on the right was misreported as sparse.
    if a_is_sparse and b_is_sparse:
        raise ValueError('A and B cannot be both sparse. \n'
                         f'A:\n{A}\n'
                         f'B:\n{B}')
    if not (a_is_sparse or b_is_sparse):
        raise ValueError('A and B cannot be both dense. \n'
                         f'A:\n{A}\n'
                         f'B:\n{B}')
    # Normalize the dense operand the same way the helpers do, so NumPy arrays
    # and sequences are accepted on either side.
    if a_is_sparse:
        return _matmul_with_left_sparse(A, as_jax(B))
    return _matmul_with_right_sparse(as_jax(A), B)
